package com.example.demo.config.db.interceptor;

import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.ReflectUtil;
import cn.hutool.core.util.StrUtil;
import com.example.demo.common.util.SM4Encryptor;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.stereotype.Component;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.*;

/**
 * 自定义 Mybatis 数据安全处理拦截器类。
 */
@Slf4j
@Component
@Intercepts({@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}), @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}), @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})})
public class CustomizeMybatisDataSecurityHandlerInterceptor implements Interceptor {
//    /**
//     * 加密服务管理对象。
//     */
//    @Autowired
//    private ApiEncryptService mApiEncryptService;
    /**
     * 查询条件 key 名称。
     */
    private static final String CRITERIA_KEY_NAME = "oredCriteria";
    /**
     * 获取条件集合 key 名称。
     */
    private static final String CRITERIA_GET_KEY_NAME = "getCriteria";
    /**
     * 获取条件集合下的条件 key 名称。
     */
    private static final String CRITERIA_GET_CONDITION_KEY_NAME = "getCondition";
    /**
     * 条件集合下的条件值 key 名称。
     */
    private static final String CRITERIA_VALUE_KEY_NAME = "value";
    /**
     * 获取条件集合下的条件值 key 名称。
     */
    private static final String CRITERIA_GET_VALUE_KEY_NAME = "getValue";
    /**
     * 方法对象缓存（便于下次快速获取对象）。<br/>
     * key：Mapper 中的方法（包括包路径）。<br/>
     * value：对应的方法对象。
     */
    private static final Map<String, Method> methodCache = new HashMap<>();
    /**
     * 需要排除的数据类型集合。
     */
    private static final List<Class> excludedDataTypeList = Arrays.asList(String.class, Integer.class, Boolean.class, Double.class, Float.class, Long.class, Short.class, Byte.class, Character.class);
 
    /***
     * 实现 intercept 方法，该方法将传递 Invocation 对象作为参数。可以使用该对象调用原始方法，并对其返回值进行处理。
     * @param invocation
     * @return
     * @throws Throwable
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] _args = invocation.getArgs();
        if (_args != null && _args.length > 1) {
            MappedStatement _ms = (MappedStatement) _args[0];
            Object _parameter = _args[1]; // 可能是实体对象、Map 对象
            if (_ms != null && _parameter != null) {
                String _orgMethodId = _ms.getId();
                // 针对 Example 的查询方式，会自动生成 xxxxByExample_COUNT 方法，故这里进行处理
                String _keyString = "ByExample_COUNT";
                if (_orgMethodId.endsWith(_keyString)) {
                    _keyString = "_COUNT";
                    _orgMethodId = _orgMethodId.substring(0, _orgMethodId.length() - _keyString.length());
                }
                Method _orgMethod = methodCache.get(_orgMethodId);
                DataSecurity _orgMethodDataSecurity = null;
                if (_orgMethod == null) {
                    String _mapperClassName = _orgMethodId.substring(0, _orgMethodId.lastIndexOf("."));
                    Class<?> _mapperClass = Class.forName(_mapperClassName);
                    if (_mapperClass != null) {
                        String _methodName = _orgMethodId.substring(_orgMethodId.lastIndexOf(".") + 1);
                        _orgMethod = Arrays.stream(_mapperClass.getMethods()).filter(method -> method.getName().equalsIgnoreCase(_methodName)).findFirst().orElse(null);
                        if (_orgMethod != null) {
                            methodCache.put(_orgMethodId, _orgMethod);
                        }
                    }
                }
                if (_orgMethod != null && _orgMethod.isAnnotationPresent(DataSecurity.class)) {
                    _orgMethodDataSecurity = _orgMethod.getAnnotation(DataSecurity.class);
                }
                switch (_ms.getSqlCommandType()) {
                    case INSERT: // 处理插入逻辑
                    case UPDATE: // 处理更新逻辑
                        // 是否为批量插入，真表示是，反之不是。
//                        boolean _isBatchInsert = _parameter instanceof Map;
                        String methodName=_orgMethod.getName();
                        boolean _isBatchInsert =methodName.toUpperCase().contains("BATCH");
                        // log.info("是否为批量插入：" + _isBatchInsert);
                        if (_isBatchInsert) {
                            Map<String, Object> _tempMap = (Map<String, Object>) _parameter;
                            String arrayKey="list";
                            try{
                                if(_tempMap.get("collection")!=null){
                                    arrayKey="collection" ;
                                }
                            }catch (Exception e){
                                log.error(e.getMessage());
                            }
                            List<?> _list = (List<?>) Optional.ofNullable(_tempMap.get(arrayKey)).orElse(new ArrayList<Map<String, Object>>());
                            if (CollectionUtil.isNotEmpty(_list)) {
                                // 克隆新对象，防止影响原始数据
                                _list = BeanUtil.copyToList(_list, _list.get(0).getClass());

                                for (Object _entity : _list) {
                                    this.encryptField(_entity, _orgMethodDataSecurity);
                                }
                            }
                            for (Map.Entry<String, Object> entry : _tempMap.entrySet()) {
                                if (entry.getValue() instanceof ArrayList) {
                                    // 更新原来的集合
                                    _tempMap.put(entry.getKey(), _list);
                                }
                            }
                        } else {
                            // 克隆新对象，防止影响原始数据
                            _parameter = BeanUtil.copyProperties(_parameter, _parameter.getClass());
                            _args[1] = _parameter;
                            this.encryptField(_parameter, _orgMethodDataSecurity);
                        }
                        break;
                    case SELECT: // 处理查询逻辑
                        if (_orgMethod != null) {
                            // 克隆新对象，防止影响原始数据
                            _parameter = this.cloneObject(_parameter);
                            this.encryptField(_parameter, _orgMethodDataSecurity);
                            _args[1] = _parameter;
                        }
                        Object _queryResult = invocation.proceed();
                        if (_orgMethod != null) {
                            // 解密，防止影响下一次的查询
                            this.decryptField(_parameter, _orgMethodDataSecurity);
                            _args[1] = _parameter;
                        }
                        boolean _isList = _queryResult instanceof ArrayList;
                        // log.info("查询结果是不是集合：" + _isList);
                        if (_isList) {
                            List<?> _list = (List<?>) _queryResult;
                            for (Object _entity : _list) {
                                this.decryptField(_entity, _orgMethodDataSecurity);
                            }
                        } else {
                            // 经测试，发现查询单条数据，也是返回集合，故这里先不进行处理。
                        }
                        return _queryResult;
                }
            }
        }
        return invocation.proceed();
    }
 
    /**
     * 克隆对象。
     *
     * @param source 源对象
     * @return 返回克隆后的对象。
     */
    private Object cloneObject(Object source) {
        Object _result = source;
        if (_result != null) {
            try {
                Class<?> _sourceClazz = _result.getClass();
                if (!_sourceClazz.getSimpleName().endsWith("Example") && this.excludedDataTypeList.indexOf(_sourceClazz) == -1) {
                    // 不需要排除的数据类型，则进行克隆
                    if (_sourceClazz == ArrayList.class || _sourceClazz == List.class) {
                        List<?> _list = (List<?>) Optional.of(_result).orElse(new ArrayList<>());
                        if (CollectionUtil.isNotEmpty(_list)) {
                            _result = BeanUtil.copyToList(_list, _list.get(0).getClass());
                        }
                    } else {
                        _result = BeanUtil.copyProperties(_result, _sourceClazz);
                    }
                }
            } catch (Exception ex) {
            }
        }
        return _result;
    }
 
    /**
     * 给字段进行加密。
     *
     * @param entity             待加密的实体对象
     * @param methodDataSecurity 方法上被注解的数据安全对象
     */
    private void encryptField(Object entity, DataSecurity methodDataSecurity) {
        this.updateField(entity, methodDataSecurity, 0);
    }
 
    /**
     * 给字段进行解密。
     *
     * @param entity             待解密的实体对象
     * @param methodDataSecurity 方法上被注解的数据安全对象
     */
    private void decryptField(Object entity, DataSecurity methodDataSecurity) {
        this.updateField(entity, methodDataSecurity, 1);
    }
 
    /**
     * 更新字段属性值。
     *
     * @param entity             待更新的实体对象
     * @param methodDataSecurity 方法上被注解的数据安全对象
     * @param type               更新类型，0：表示加密，1：表示解密
     */
    private void updateField(Object entity, DataSecurity methodDataSecurity, int type) {
        if (entity == null) {
            return;
        }
        if (entity instanceof Map) {
            Map<String, Object> _tempMap = (Map<String, Object>) entity;
            Set<Map.Entry<String, Object>> _entrys = _tempMap.entrySet();
            List<String> _tempDataSecurityValueList = new ArrayList<>();
            if (methodDataSecurity != null) {
                _tempDataSecurityValueList = Arrays.asList(methodDataSecurity.value());
            }
            for (Map.Entry<String, Object> _entry : _entrys) {
                if (_entry.getKey().equalsIgnoreCase(CRITERIA_KEY_NAME)) {
                    this.handlerCriteria(_entry.getValue(), methodDataSecurity, type);
                } else {
                    if (_tempDataSecurityValueList.contains(_entry.getKey())) {
                        _tempMap.put(_entry.getKey(), this.formatValue(methodDataSecurity, _entry.getValue().toString(), type));
                    } else {
                        this.updateField(_entry.getValue(), methodDataSecurity, type);
                    }
                }
            }
            return;
        }
        if (entity instanceof ArrayList) {
            List<?> _list = (List<?>) entity;
            for (Object _item : _list) {
                this.updateField(_item, methodDataSecurity, type);
            }
            return;
        }
        Class<?> _clazz = entity.getClass();
        Field[] _fields = _clazz.getDeclaredFields();
        for (Field _field : _fields) {
            _field.setAccessible(true);
            if (methodDataSecurity != null && _field.getName().equalsIgnoreCase(CRITERIA_KEY_NAME)) {
                // 针对 Example 的查询方式进行特殊处理
                try {
                    this.handlerCriteria(_field.get(entity), methodDataSecurity, type);
                } catch (IllegalAccessException e) {
                    // throw new RuntimeException(e);
                }
            } else if (_field.isAnnotationPresent(DataSecurity.class)) {
                // 获取要加密的字段值
                DataSecurity _mDataSecurity = _field.getAnnotation(DataSecurity.class);
                try {
                    Object _value = _field.get(entity);
                    if (_value != null) {
                        Object _tempValue = this.formatValue(_mDataSecurity, _value.toString(), type);
                        if (_tempValue != null) {
                            // 将加解密后的数据设置回去
                            _field.set(entity, _tempValue);
                        }
                    }
                } catch (IllegalAccessException e) {
                    // throw new RuntimeException(e);
                }
            }
        }
    }
 
    /**
     * 处理 Example 查询数据。
     *
     * @param entity       待更新的实体对象
     * @param dataSecurity 方法上被注解的数据安全对象
     * @param type         更新类型，0：表示加密，1：表示解密
     */
    private void handlerCriteria(Object entity, DataSecurity dataSecurity, int type) {
        if (entity == null || dataSecurity == null) {
            return;
        }
        // 针对 Example 的查询方式进行特殊处理
        List<?> _list = (List<?>) entity;
        // log.info("{}", _list);
        List<String> _tempDataSecurityValueList = Arrays.asList(dataSecurity.value());
        if (CollectionUtil.isNotEmpty(_list) && CollectionUtil.isNotEmpty(_tempDataSecurityValueList)) {
            _list = ReflectUtil.invoke(_list.get(0), CRITERIA_GET_KEY_NAME);
            for (Object _item : _list) {
                try {
                    String _tempCondition = ReflectUtil.invoke(_item, CRITERIA_GET_CONDITION_KEY_NAME);
                    // log.info("{}", _tempCondition);
                    String _keyString = " =";
                    if (!_tempCondition.trim().endsWith(_keyString)) {
                        continue;
                    }
                    _tempCondition = StrUtil.toCamelCase(_tempCondition.substring(0, _tempCondition.length() - _keyString.length()));
                    if (!_tempDataSecurityValueList.contains(_tempCondition)) {
                        continue;
                    }
                    Object _tempValue = ReflectUtil.invoke(_item, CRITERIA_GET_VALUE_KEY_NAME);
                    if (!(_tempValue instanceof String)) {
                        continue;
                    }
                    // log.info("{} {}", _tempCondition, _tempValue);
                    _tempValue = this.formatValue(dataSecurity, _tempValue.toString(), type);
                    if (_tempValue != null) {
                        try {
                            Field _tempField = ReflectUtil.getField(_item.getClass(), CRITERIA_VALUE_KEY_NAME);
                            _tempField.setAccessible(true);
                            _tempField.set(_item, _tempValue);
                        } catch (IllegalAccessException e) {
                            // throw new RuntimeException(e);
                        }
                    }
                    // log.info("{} {}", _tempCondition, _tempValue);
                } catch (Exception ex) {
                }
            }
        }
    }
 
    /**
     * 格式化值。
     *
     * @param dataSecurity 数据安全对象
     * @param orgValue     原数据值
     * @param type         更新类型，0：表示加密，1：表示解密
     * @return 返回格式化后的值。
     */
    private String formatValue(DataSecurity dataSecurity, String orgValue, int type) {
        if (dataSecurity == null || orgValue == null) {
            return orgValue;
        }
        String _tempValue = null;
        try {
            // 使用加解密算法进行加解密
            switch (type) {
                case 0: // 加密
                    _tempValue = SM4Encryptor.encryptStr(orgValue);
//                    switch (dataSecurity.algorithm()) {
//                        case BASE64:
//                            _tempValue = Base64Utils.encodeToString(orgValue.getBytes(StandardCharsets.UTF_8));
//                            break;
//                        case SM2:
//                            _tempValue = this.mApiEncryptService.encrypt2Data(orgValue);
//                            break;
//                        case SM4:
//                            _tempValue = this.mApiEncryptService.encrypt4Data(orgValue);
//                            break;
//                    }
                    break;
                case 1: // 解密
//                    switch (dataSecurity.algorithm()) {
//                        case BASE64:
//                            _tempValue = new String(Base64Utils.decodeFromString(orgValue), StandardCharsets.UTF_8);
//                            break;
//                        case SM2:
//                            _tempValue = this.mApiEncryptService.decrypt2Data(orgValue);
//                            break;
//                        case SM4:
//                            _tempValue = this.mApiEncryptService.decrypt4Data(orgValue);
//                            break;
//                    }
                    _tempValue = SM4Encryptor.decryptStr(orgValue);
                    break;
            }
        } catch (Exception ex) {
        }
        if (_tempValue != null) {
            return _tempValue;
        }
        return orgValue;
    }
}